import pandas as pd
import numpy as np
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from openpyxl import load_workbook
import start

# %%

# Output workbook for the performance metrics.
RESULTS_FILE = start.MAIN_DIR + "results/performance_characters_test.xlsx"

# GPT classifications and the human-annotated gold standard.
CLASSIFICATION_FILE = (
    start.MAIN_DIR + "data/clean/gpt_classifications_characters_test.xlsx"
)
ANNOTATIONS_FILE = start.MAIN_DIR + "data/clean/character_classifications_gold.xlsx"

gold = pd.read_excel(ANNOTATIONS_FILE)
gold = gold[["unique_id", "character_gold", "set"]]

# Read through the constant so the path is defined in exactly one place.
gpt = pd.read_excel(CLASSIFICATION_FILE)
gpt = gpt[["unique_id", "response"]]

# Inner join: keep only rows that have both a GPT response and a gold label.
df = gpt.merge(gold, on="unique_id", how="inner")
df["gold_standard"] = df.character_gold
df["gpt_classification"] = df.response

# Restrict to rows whose "set" column is "test". The explicit copy makes the
# filtered frame independent of the merged one, so the column assignments that
# follow do not trigger SettingWithCopyWarning.
df = df[df.set == "test"].copy()

# 1 where the GPT label equals the gold label, 0 otherwise.
df["agree"] = np.where(df.gold_standard == df.gpt_classification, 1, 0)


# %%
def bootstrap_metric(metric_fn, y_true, y_pred, n_bootstraps=1000, random_state=12):
    """Bootstrap the mean and standard error of a paired classification metric.

    The observations are resampled with replacement ``n_bootstraps`` times.
    The same resampled positions are applied to ``y_true`` and ``y_pred`` so
    every label stays paired with its prediction, and ``metric_fn`` is
    evaluated on each resample.

    Args:
        metric_fn: Callable invoked as ``metric_fn(y_true_sample, y_pred_sample)``
            with two NumPy arrays; must return a scalar score.
        y_true: Sequence of reference labels.
        y_pred: Sequence of predicted labels, same length as ``y_true``.
        n_bootstraps: Number of bootstrap resamples to draw.
        random_state: Seed for ``numpy.random.default_rng``; a fixed seed
            makes the result reproducible.

    Returns:
        Tuple ``(mean, std)``: the mean of the bootstrap scores and their
        sample standard deviation (``ddof=1``).

    Raises:
        ValueError: If the inputs are empty or their lengths differ.
    """
    # Convert once up front instead of on every resample.
    true_arr = np.asarray(y_true)
    pred_arr = np.asarray(y_pred)

    n = len(true_arr)
    if n == 0:
        raise ValueError("cannot bootstrap a metric on empty input")
    if len(pred_arr) != n:
        raise ValueError(
            f"y_true and y_pred must have the same length ({n} != {len(pred_arr)})"
        )

    rng = np.random.default_rng(seed=random_state)
    scores = np.empty(n_bootstraps, dtype=float)
    for i in range(n_bootstraps):
        # Positional indices drawn with replacement from [0, n).
        indices = rng.integers(0, n, n)
        scores[i] = metric_fn(true_arr[indices], pred_arr[indices])

    return scores.mean(), scores.std(ddof=1)


def bootstrap_accuracy(y_true, y_pred, n_bootstraps=1000, random_state=12):
    """Return the bootstrap ``(mean, std)`` of ``accuracy_score``."""
    return bootstrap_metric(accuracy_score, y_true, y_pred, n_bootstraps, random_state)


def bootstrap_precision(y_true, y_pred, n_bootstraps=1000, random_state=12):
    """Return the bootstrap ``(mean, std)`` of ``precision_score``."""
    return bootstrap_metric(
        precision_score, y_true, y_pred, n_bootstraps, random_state
    )


def bootstrap_recall(y_true, y_pred, n_bootstraps=1000, random_state=12):
    """Return the bootstrap ``(mean, std)`` of ``recall_score``."""
    return bootstrap_metric(recall_score, y_true, y_pred, n_bootstraps, random_state)


def bootstrap_f1(y_true, y_pred, n_bootstraps=1000, random_state=12):
    """Return the bootstrap ``(mean, std)`` of ``f1_score``."""
    return bootstrap_metric(f1_score, y_true, y_pred, n_bootstraps, random_state)


# Sheet names of the classification workbook. They are not used further down
# in this cell; read-only mode avoids loading all cell data just to list the
# sheets, and a read-only workbook must be closed explicitly.
classification_wb = load_workbook(CLASSIFICATION_FILE, read_only=True)
try:
    sheet_names = classification_wb.sheetnames
finally:
    classification_wb.close()

# Workbook that receives the metrics; the "results" sheet is looked up once.
wb = load_workbook(RESULTS_FILE)
ws = wb["results"]

row = 2
start_result_col = 2

for character in ["hero", "villain", "victim"]:
    # One-vs-rest indicator columns: 1 if the label is this character, else 0.
    df["gold_" + character] = np.where(df.gold_standard == character, 1, 0)
    df["gpt_" + character] = np.where(df.gpt_classification == character, 1, 0)
    y_true = df["gold_" + character]
    y_pred = df["gpt_" + character]

    # Point estimates on the full sample.
    accuracy = accuracy_score(y_true, y_pred)
    precision = precision_score(y_true, y_pred)
    recall = recall_score(y_true, y_pred)
    f1 = f1_score(y_true, y_pred)

    # Bootstrap standard errors; the bootstrap means are not reported.
    _, accuracy_se = bootstrap_accuracy(y_true, y_pred)
    _, precision_se = bootstrap_precision(y_true, y_pred)
    _, recall_se = bootstrap_recall(y_true, y_pred)
    _, f1_se = bootstrap_f1(y_true, y_pred)

    # First row for this character: its name followed by the point estimates.
    # round(float(...)) works whether the metric comes back as a built-in
    # float or a NumPy scalar; a built-in float has no .round() method.
    ws.cell(row=row, column=1, value=character)
    for col, metric in enumerate(
        [accuracy, precision, recall, f1], start=start_result_col
    ):
        ws.cell(row=row, column=col, value=round(float(metric), 2))
    row += 1

    # Second row: the standard errors in parentheses, under their estimates.
    for col, metric in enumerate(
        [accuracy_se, precision_se, recall_se, f1_se], start=start_result_col
    ):
        ws.cell(row=row, column=col, value=f"({round(float(metric), 2)})")
    row += 1

wb.save(RESULTS_FILE)

# %%
